// Copyright 2021 The Chromium Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include <memory>

#include "base/callback.h"
#include "base/files/file_path.h"
#include "base/memory/raw_ptr.h"
#include "base/memory/scoped_refptr.h"
#include "base/memory/weak_ptr.h"
#include "base/task/post_task.h"
#include "base/test/bind.h"
#include "chrome/browser/chrome_content_browser_client.h"
#include "chrome/browser/extensions/extension_browsertest.h"
#include "chrome/browser/extensions/extension_tab_util.h"
#include "chrome/browser/ui/browser.h"
#include "chrome/browser/ui/tabs/tab_strip_model.h"
#include "chrome/test/base/ui_test_utils.h"
#include "content/public/browser/browser_message_filter.h"
#include "content/public/browser/browser_task_traits.h"
#include "content/public/browser/browser_thread.h"
#include "content/public/browser/render_frame_host.h"
#include "content/public/browser/render_process_host.h"
#include "content/public/browser/web_contents.h"
#include "content/public/common/content_client.h"
#include "content/public/test/browser_test.h"
#include "content/public/test/browser_test_utils.h"
#include "extensions/browser/api/storage/storage_api.h"
#include "extensions/browser/bad_message.h"
#include "extensions/browser/browsertest_util.h"
#include "extensions/common/constants.h"
#include "extensions/common/extension_messages.h"
#include "extensions/test/test_extension_dir.h"
#include "ipc/ipc_security_test_util.h"
#include "net/dns/mock_host_resolver.h"
#include "net/test/embedded_test_server/embedded_test_server.h"
#include "testing/gmock/include/gmock/gmock.h"
#include "testing/gtest/include/gtest/gtest.h"
#include "third_party/blink/public/mojom/service_worker/service_worker_database.mojom-forward.h"
#include "url/gurl.h"

namespace extensions {

// Test helper that blocks until the given RenderProcessHost goes away and
// reports which BadMessageReason (if any) //extensions used to kill it.
//
// Example usage:
//   RenderProcessHostBadIpcMessageWaiter kill_waiter(render_process_host);
//   ... test code that triggers a renderer kill ...
//   EXPECT_EQ(bad_message::EFD_BAD_MESSAGE_PROCESS, kill_waiter.Wait());
class RenderProcessHostBadIpcMessageWaiter {
 public:
  // Starts observing `render_process_host`.  The wrapped kill waiter is
  // configured with the "Stability.BadMessageTerminated.Extensions" name.
  explicit RenderProcessHostBadIpcMessageWaiter(
      content::RenderProcessHost* render_process_host)
      : internal_waiter_(render_process_host,
                         "Stability.BadMessageTerminated.Extensions") {}

  // Not copyable or assignable.
  RenderProcessHostBadIpcMessageWaiter(
      const RenderProcessHostBadIpcMessageWaiter&) = delete;
  RenderProcessHostBadIpcMessageWaiter& operator=(
      const RenderProcessHostBadIpcMessageWaiter&) = delete;

  // Blocks until the renderer process exits.  Yields the bad message reason
  // that made //extensions kill the renderer, or `absl::nullopt` when the
  // renderer was killed outside of //extensions or exited normally.
  [[nodiscard]] absl::optional<bad_message::BadMessageReason> Wait() {
    absl::optional<bad_message::BadMessageReason> reason;
    // The wrapped waiter reports the reason as a plain integer; convert it
    // back into the //extensions enum when one was reported.
    if (const absl::optional<int> raw_reason = internal_waiter_.Wait())
      reason = static_cast<bad_message::BadMessageReason>(*raw_reason);
    return reason;
  }

 private:
  content::RenderProcessHostKillWaiter internal_waiter_;
};

// Intercepts legacy IPC messages of type TMessageType. Only meant for
// intercepting message once.
template <typename TMessageType>
class ExtensionMessageWaiter {
 public:
  using MessageParamType = typename TMessageType::Param;

  ExtensionMessageWaiter() : weak_ptr_factory_(this) {
    test_content_browser_client_ = std::make_unique<TestContentBrowserClient>(
        weak_ptr_factory_.GetWeakPtr());
  }

  ~ExtensionMessageWaiter() = default;

  ExtensionMessageWaiter(const ExtensionMessageWaiter&) = delete;
  ExtensionMessageWaiter& operator=(const ExtensionMessageWaiter&) = delete;

  using IpcMatcher =
      base::RepeatingCallback<bool(int captured_render_process_id,
                                   const MessageParamType& param)>;
  void SetIpcMatcher(IpcMatcher ipc_matcher) { ipc_matcher_ = ipc_matcher; }

  MessageParamType WaitForMessage() {
    DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

    // Wait for `captured_message_param_` in a nested message loop if needed
    // (i.e. if CaptureMessageParam hasn't called `run_loop_.Quit()` yet).
    run_loop_.Run();

    // Return the `captured_message_param_`.
    DCHECK(captured_message_param_.has_value());
    return *captured_message_param_;
  }

 private:
  void CaptureMessageParam(MessageParamType param, int render_process_id) {
    DCHECK_CURRENTLY_ON(content::BrowserThread::UI);

    // Do nothing if we already captured a matching IPC.
    if (captured_message_param_.has_value())
      return;

    // Do nothing if the IPC doesn't match.
    if (ipc_matcher_.is_null() || !ipc_matcher_.Run(render_process_id, param))
      return;

    // Capture the IPC payload.
    captured_message_param_.emplace(std::move(param));
    captured_render_process_id_ = render_process_id;

    // Once we have `captured_message_param_` there is no need to inject
    // TestFilter into additional RenderProcessHosts.
    test_content_browser_client_.reset();

    // Wake up WaitForMessage if necessary.
    run_loop_.Quit();
  }

  // A BrowserMessageFilter implementation that posts a copy of the payload of
  // the TMessageType into ExtensionMessageWaiter::CaptureMessageParam, but
  // otherwise leaves all the message unhandled (i.e. allows other filters to
  // process the message).
  class TestFilter : public content::BrowserMessageFilter {
   public:
    TestFilter(base::WeakPtr<ExtensionMessageWaiter> ipc_message_waiter,
               int render_process_id)
        : content::BrowserMessageFilter(ExtensionMsgStart),
          ipc_message_waiter_(ipc_message_waiter),
          render_process_id_(render_process_id) {}

   private:
    // content::BrowserMessageFilter overrides:
    bool OnMessageReceived(const IPC::Message& message) override {
      DCHECK_CURRENTLY_ON(content::BrowserThread::IO);

      MessageParamType param;
      if (message.type() == TMessageType::ID &&
          TMessageType::Read(&message, &param)) {
        base::PostTask(
            FROM_HERE, {content::BrowserThread::UI},
            base::BindOnce(&ExtensionMessageWaiter::CaptureMessageParam,
                           ipc_message_waiter_, std::move(param),
                           render_process_id_));
      }

      return false;  // Not handled - let another filter handle the message.
    }

    ~TestFilter() override = default;

    base::WeakPtr<ExtensionMessageWaiter> ipc_message_waiter_;
    const int render_process_id_;
  };

  // A content::ContentBrowserClient that injects a TestFilter (as the very
  // first filter) into all new RenderProcessHost objects, but otherwise behaves
  // identically to ChromeContentBrowserClient.
  class TestContentBrowserClient : public ChromeContentBrowserClient {
   public:
    explicit TestContentBrowserClient(
        base::WeakPtr<ExtensionMessageWaiter> ipc_message_waiter)
        : ipc_message_waiter_(ipc_message_waiter) {
      old_client_ = content::SetBrowserClientForTesting(this);
    }

    ~TestContentBrowserClient() override {
      content::SetBrowserClientForTesting(old_client_);
    }

   private:
    // content::ContentBrowserClient overrides:
    void RenderProcessWillLaunch(content::RenderProcessHost* host) override {
      auto test_filter =
          base::MakeRefCounted<TestFilter>(ipc_message_waiter_, host->GetID());
      host->AddFilter(test_filter.get());

      ChromeContentBrowserClient::RenderProcessWillLaunch(host);
    }

    base::WeakPtr<ExtensionMessageWaiter> ipc_message_waiter_;
    raw_ptr<content::ContentBrowserClient> old_client_;
  };

  std::unique_ptr<TestContentBrowserClient> test_content_browser_client_;
  base::RunLoop run_loop_;
  IpcMatcher ipc_matcher_;
  absl::optional<MessageParamType> captured_message_param_;
  int captured_render_process_id_ = -1;

  base::WeakPtrFactory<ExtensionMessageWaiter> weak_ptr_factory_;
};

// Test suite covering how mojo/IPC messages are verified after being received
// from a (potentially compromised) renderer process.
class OpenChannelToExtensionExploitTest : public ExtensionBrowserTest {
 public:
  OpenChannelToExtensionExploitTest() = default;

  void SetUpOnMainThread() override {
    ExtensionBrowserTest::SetUpOnMainThread();

    host_resolver()->AddRule("*", "127.0.0.1");
    content::SetupCrossSiteRedirector(embedded_test_server());
    ASSERT_TRUE(embedded_test_server()->Start());

    InstallTestExtensions();

    GURL test_page_url =
        embedded_test_server()->GetURL("foo.com", "/title1.html");
    ipc_message_waiter_ = StartInterceptingIpcs(test_page_url);
  }

  // Waits for ExtensionHostMsg_OpenChannelToExtension IPC and returns its
  // payload.
  ExtensionHostMsg_OpenChannelToExtension::Param WaitForMessage() {
    return ipc_message_waiter_->WaitForMessage();
  }

  // Asks the `extension_id` to inject `content_script` into `web_contents`.
  // Returns true if the content script execution started successfully.
  bool ExecuteProgrammaticContentScript(content::WebContents* web_contents,
                                        const ExtensionId& extension_id,
                                        const std::string& content_script) {
    DCHECK(web_contents);
    int tab_id = ExtensionTabUtil::GetTabId(web_contents);
    std::string background_script = content::JsReplace(
        "chrome.tabs.executeScript($1, { code: $2 });", tab_id, content_script);
    return browsertest_util::ExecuteScriptInBackgroundPageNoWait(
        browser()->profile(), extension_id, background_script);
  }

  const ExtensionId& active_extension_id() { return active_extension_->id(); }

  const ExtensionId& spoofed_extension_id() { return spoofed_extension_->id(); }

 private:
  void InstallTestExtensions() {
    // Install an `active_extension` and a separate, but otherwise identical
    // `spoofed_extension` (the only difference will be the extension id).
    auto install_extension = [this](TestExtensionDir& dir) -> const Extension* {
      const char kManifestTemplate[] = R"(
          {
            "name": "ContentScriptTrackerBrowserTest - Programmatic",
            "version": "1.0",
            "manifest_version": 2,
            "permissions": [ "tabs", "<all_urls>" ],
            "background": {"scripts": ["background_script.js"]}
          } )";
      dir.WriteManifest(kManifestTemplate);
      dir.WriteFile(FILE_PATH_LITERAL("background_script.js"), "");
      return LoadExtension(dir.UnpackedPath());
    };
    TestExtensionDir active_dir;
    TestExtensionDir spoofed_dir;
    active_extension_ = install_extension(active_dir);
    spoofed_extension_ = install_extension(spoofed_dir);
    ASSERT_TRUE(active_extension_);
    ASSERT_TRUE(spoofed_extension_);
    ASSERT_NE(active_extension_id(), spoofed_extension_id());
  }

  using OpenChannelMessageWaiter =
      ExtensionMessageWaiter<ExtensionHostMsg_OpenChannelToExtension>;
  std::unique_ptr<OpenChannelMessageWaiter> StartInterceptingIpcs(
      const GURL& test_page_url) {
    // Start capturing IPC messages in all future/new RenderProcessHosts.
    auto ipc_message_waiter = std::make_unique<OpenChannelMessageWaiter>();

    // Navigate to an arbitrary, mostly empty test page.  Make sure that a new
    // RenderProcessHost is created to make sure it is covered by the
    // `ipc_message_waiter`.  (A WebUI -> http navigation should swap the
    // RenderProcessHost on all platforms.)
    content::WebContents* web_contents =
        browser()->tab_strip_model()->GetActiveWebContents();
    int old_process_id = web_contents->GetMainFrame()->GetProcess()->GetID();
    EXPECT_TRUE(
        ui_test_utils::NavigateToURL(browser(), GURL("chrome://version")));
    EXPECT_TRUE(ui_test_utils::NavigateToURL(browser(), test_page_url));
    int new_process_id = web_contents->GetMainFrame()->GetProcess()->GetID();
    EXPECT_NE(old_process_id, new_process_id);

    // Only intercept messages from `active_extension`'s content script running
    // in the main frame's process.
    std::string matching_extension_id = active_extension_id();
    int matching_process_id = new_process_id;
    ipc_message_waiter->SetIpcMatcher(base::BindLambdaForTesting(
        [matching_extension_id, matching_process_id](
            int captured_render_process_id,
            const ExtensionHostMsg_OpenChannelToExtension::Param& param) {
          if (captured_render_process_id != matching_process_id)
            return false;

          extensions::PortContext source_context;
          ExtensionMsg_ExternalConnectionInfo info;
          std::string channel_name;
          extensions::PortId port_id;
          std::tie(source_context, info, channel_name, port_id) = param;

          if (info.source_endpoint.extension_id != matching_extension_id)
            return false;

          if (info.source_endpoint.type != MessagingEndpoint::Type::kTab)
            return false;

          return true;
        }));

    return ipc_message_waiter;
  }

  std::unique_ptr<OpenChannelMessageWaiter> ipc_message_waiter_;

  raw_ptr<const Extension> active_extension_ = nullptr;
  raw_ptr<const Extension> spoofed_extension_ = nullptr;
};

// Verifies that the renderer is killed when an OpenChannelToExtension IPC sent
// from a content script names a different extension as its source.
IN_PROC_BROWSER_TEST_F(OpenChannelToExtensionExploitTest,
                       FromContentScript_BadExtensionIdInMessagingSource) {
  // Have a content script of the active extension send a message, so that a
  // legitimate ExtensionHostMsg_OpenChannelToExtension IPC goes out.
  static constexpr char kSendMessageScript[] =
      "chrome.runtime.sendMessage({greeting: 'hello'}, (response) => {});";
  content::WebContents* const tab =
      browser()->tab_strip_model()->GetActiveWebContents();
  ASSERT_TRUE(ExecuteProgrammaticContentScript(tab, active_extension_id(),
                                               kSendMessageScript));

  // Grab the payload of the legitimate IPC.
  extensions::PortContext context;
  ExtensionMsg_ExternalConnectionInfo connection_info;
  std::string name;
  extensions::PortId id;
  std::tie(context, connection_info, name, id) = WaitForMessage();

  // Sanity-check the captured source endpoint, then swap in the id of the
  // other (spoofed) extension.
  auto& source = connection_info.source_endpoint;
  EXPECT_EQ(MessagingEndpoint::Type::kTab, source.type);
  EXPECT_EQ(active_extension_id(), source.extension_id);
  source.extension_id = spoofed_extension_id();

  // Replay the tampered IPC as if it came from the tab's renderer.  The kill
  // waiter has to exist before the IPC is injected.
  content::RenderProcessHost* const renderer =
      tab->GetMainFrame()->GetProcess();
  RenderProcessHostBadIpcMessageWaiter kill_waiter(renderer);
  const ExtensionHostMsg_OpenChannelToExtension forged_ipc(
      context, connection_info, name, id);
  IPC::IpcSecurityTestUtil::PwnMessageReceived(renderer->GetChannel(),
                                               forged_ipc);

  // The renderer should be terminated with the expected reason.
  const absl::optional<bad_message::BadMessageReason> kill_reason =
      kill_waiter.Wait();
  EXPECT_EQ(bad_message::EMF_INVALID_EXTENSION_ID_FOR_CONTENT_SCRIPT,
            kill_reason);
}

// Verifies that the renderer is killed when an OpenChannelToExtension IPC sent
// from a content script claims to originate from a native app.
IN_PROC_BROWSER_TEST_F(OpenChannelToExtensionExploitTest,
                       FromContentScript_UnexpectedNativeAppType) {
  // Have a content script of the active extension send a message, so that a
  // legitimate ExtensionHostMsg_OpenChannelToExtension IPC goes out.
  static constexpr char kSendMessageScript[] =
      "chrome.runtime.sendMessage({greeting: 'hello'}, (response) => {});";
  content::WebContents* const tab =
      browser()->tab_strip_model()->GetActiveWebContents();
  ASSERT_TRUE(ExecuteProgrammaticContentScript(tab, active_extension_id(),
                                               kSendMessageScript));

  // Grab the payload of the legitimate IPC.
  extensions::PortContext context;
  ExtensionMsg_ExternalConnectionInfo connection_info;
  std::string name;
  extensions::PortId id;
  std::tie(context, connection_info, name, id) = WaitForMessage();

  // Sanity-check the captured source endpoint, then pretend that the sender
  // is a native app.
  auto& source = connection_info.source_endpoint;
  EXPECT_EQ(MessagingEndpoint::Type::kTab, source.type);
  EXPECT_EQ(active_extension_id(), source.extension_id);
  source.type = MessagingEndpoint::Type::kNativeApp;

  // Replay the tampered IPC as if it came from the tab's renderer.  The kill
  // waiter has to exist before the IPC is injected.
  content::RenderProcessHost* const renderer =
      tab->GetMainFrame()->GetProcess();
  RenderProcessHostBadIpcMessageWaiter kill_waiter(renderer);
  const ExtensionHostMsg_OpenChannelToExtension forged_ipc(
      context, connection_info, name, id);
  IPC::IpcSecurityTestUtil::PwnMessageReceived(renderer->GetChannel(),
                                               forged_ipc);

  // The renderer should be terminated with the expected reason.
  const absl::optional<bad_message::BadMessageReason> kill_reason =
      kill_waiter.Wait();
  EXPECT_EQ(bad_message::EMF_INVALID_CHANNEL_SOURCE_TYPE, kill_reason);
}

// Verifies that the renderer is killed when an OpenChannelToExtension IPC sent
// from a content script claims to originate from an extension context.
IN_PROC_BROWSER_TEST_F(OpenChannelToExtensionExploitTest,
                       FromContentScript_UnexpectedExtensionType) {
  // Have a content script of the active extension send a message, so that a
  // legitimate ExtensionHostMsg_OpenChannelToExtension IPC goes out.
  static constexpr char kSendMessageScript[] =
      "chrome.runtime.sendMessage({greeting: 'hello'}, (response) => {});";
  content::WebContents* const tab =
      browser()->tab_strip_model()->GetActiveWebContents();
  ASSERT_TRUE(ExecuteProgrammaticContentScript(tab, active_extension_id(),
                                               kSendMessageScript));

  // Grab the payload of the legitimate IPC.
  extensions::PortContext context;
  ExtensionMsg_ExternalConnectionInfo connection_info;
  std::string name;
  extensions::PortId id;
  std::tie(context, connection_info, name, id) = WaitForMessage();

  // Sanity-check the captured source endpoint, then change the endpoint type
  // to an extension while keeping the extension id untouched.
  auto& source = connection_info.source_endpoint;
  EXPECT_EQ(MessagingEndpoint::Type::kTab, source.type);
  EXPECT_EQ(active_extension_id(), source.extension_id);
  source.type = MessagingEndpoint::Type::kExtension;

  // Replay the tampered IPC as if it came from the tab's renderer.  The kill
  // waiter has to exist before the IPC is injected.
  content::RenderProcessHost* const renderer =
      tab->GetMainFrame()->GetProcess();
  RenderProcessHostBadIpcMessageWaiter kill_waiter(renderer);
  const ExtensionHostMsg_OpenChannelToExtension forged_ipc(
      context, connection_info, name, id);
  IPC::IpcSecurityTestUtil::PwnMessageReceived(renderer->GetChannel(),
                                               forged_ipc);

  // The renderer should be terminated with the expected reason.
  const absl::optional<bad_message::BadMessageReason> kill_reason =
      kill_waiter.Wait();
  EXPECT_EQ(bad_message::EMF_INVALID_EXTENSION_ID_FOR_EXTENSION_SOURCE,
            kill_reason);
}

// Verifies that the renderer is killed when an OpenChannelToExtension IPC
// claims an extension source endpoint without providing any extension id.
IN_PROC_BROWSER_TEST_F(OpenChannelToExtensionExploitTest,
                       FromContentScript_NoExtensionIdForExtensionType) {
  // Have a content script of the active extension send a message, so that a
  // legitimate ExtensionHostMsg_OpenChannelToExtension IPC goes out.
  static constexpr char kSendMessageScript[] =
      "chrome.runtime.sendMessage({greeting: 'hello'}, (response) => {});";
  content::WebContents* const tab =
      browser()->tab_strip_model()->GetActiveWebContents();
  ASSERT_TRUE(ExecuteProgrammaticContentScript(tab, active_extension_id(),
                                               kSendMessageScript));

  // Grab the payload of the legitimate IPC.
  extensions::PortContext context;
  ExtensionMsg_ExternalConnectionInfo connection_info;
  std::string name;
  extensions::PortId id;
  std::tie(context, connection_info, name, id) = WaitForMessage();

  // Sanity-check the captured source endpoint, then turn it into an extension
  // endpoint that carries no extension id at all.
  auto& source = connection_info.source_endpoint;
  EXPECT_EQ(MessagingEndpoint::Type::kTab, source.type);
  EXPECT_EQ(active_extension_id(), source.extension_id);
  source.type = MessagingEndpoint::Type::kExtension;
  source.extension_id = absl::nullopt;

  // Replay the tampered IPC as if it came from the tab's renderer.  The kill
  // waiter has to exist before the IPC is injected.
  content::RenderProcessHost* const renderer =
      tab->GetMainFrame()->GetProcess();
  RenderProcessHostBadIpcMessageWaiter kill_waiter(renderer);
  const ExtensionHostMsg_OpenChannelToExtension forged_ipc(
      context, connection_info, name, id);
  IPC::IpcSecurityTestUtil::PwnMessageReceived(renderer->GetChannel(),
                                               forged_ipc);

  // The renderer should be terminated with the expected reason.
  const absl::optional<bad_message::BadMessageReason> kill_reason =
      kill_waiter.Wait();
  EXPECT_EQ(bad_message::EMF_NO_EXTENSION_ID_FOR_EXTENSION_SOURCE,
            kill_reason);
}

// Verifies that the renderer is killed when an OpenChannelToExtension IPC sent
// from a content script claims a service worker as its source context.
IN_PROC_BROWSER_TEST_F(OpenChannelToExtensionExploitTest,
                       FromContentScript_UnexpectedWorkerContext) {
  // Have a content script of the active extension send a message, so that a
  // legitimate ExtensionHostMsg_OpenChannelToExtension IPC goes out.
  static constexpr char kSendMessageScript[] =
      "chrome.runtime.sendMessage({greeting: 'hello'}, (response) => {});";
  content::WebContents* const tab =
      browser()->tab_strip_model()->GetActiveWebContents();
  ASSERT_TRUE(ExecuteProgrammaticContentScript(tab, active_extension_id(),
                                               kSendMessageScript));

  // Grab the payload of the legitimate IPC.
  extensions::PortContext context;
  ExtensionMsg_ExternalConnectionInfo connection_info;
  std::string name;
  extensions::PortId id;
  std::tie(context, connection_info, name, id) = WaitForMessage();

  // The genuine IPC comes from a frame.  Replace the frame context with a
  // made-up worker context attributed to the active extension.
  EXPECT_TRUE(context.is_for_render_frame());
  EXPECT_FALSE(context.is_for_service_worker());
  context.frame = absl::nullopt;
  context.worker = PortContext::WorkerContext(
      /* thread_id = */ 123, /* version_id = */ 456,
      /* extension_id = */ active_extension_id());

  // Replay the tampered IPC as if it came from the tab's renderer.  The kill
  // waiter has to exist before the IPC is injected.
  content::RenderProcessHost* const renderer =
      tab->GetMainFrame()->GetProcess();
  RenderProcessHostBadIpcMessageWaiter kill_waiter(renderer);
  const ExtensionHostMsg_OpenChannelToExtension forged_ipc(
      context, connection_info, name, id);
  IPC::IpcSecurityTestUtil::PwnMessageReceived(renderer->GetChannel(),
                                               forged_ipc);

  // The renderer should be terminated with the expected reason.
  const absl::optional<bad_message::BadMessageReason> kill_reason =
      kill_waiter.Wait();
  EXPECT_EQ(bad_message::EMF_INVALID_EXTENSION_ID_FOR_WORKER_CONTEXT,
            kill_reason);
}

}  // namespace extensions
